# import torch
# import torch.nn.functional as F
# from torch_geometric.datasets import Planetoid
# import torch_geometric.transforms as T
# from torch_geometric.nn import GINConv, Sequential, global_add_pool

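# # Define the GIN model.
# # NOTE(review): forward() ends with global_add_pool(x, batch), which sums
# # node embeddings into one vector per graph (a graph-classification head).
# # Cora, loaded below, is node classification on a single graph, so the
# # pooling step likely needs to be removed to get one prediction per node.
# # NOTE(review): torch_geometric.nn.Sequential takes an input-args string as
# # its first argument; these plain Linear/ReLU chains presumably want
# # torch.nn.Sequential instead -- confirm before uncommenting.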
# class GIN(torch.nn.Module):
#     def __init__(self, num_features, num_classes):
#         super(GIN, self).__init__()
#         dim = 32
#         self.conv1 = GINConv(Sequential(torch.nn.Linear(num_features, dim), torch.nn.ReLU(), torch.nn.Linear(dim, dim), torch.nn.ReLU()))
#         self.conv2 = GINConv(Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU(), torch.nn.Linear(dim, dim), torch.nn.ReLU()))
#         self.conv3 = GINConv(Sequential(torch.nn.Linear(dim, dim), torch.nn.ReLU(), torch.nn.Linear(dim, dim), torch.nn.ReLU()))
#         self.lin = torch.nn.Linear(dim, num_classes)

#     def forward(self, x, edge_index, batch):
#         x = F.relu(self.conv1(x, edge_index))
#         x = F.relu(self.conv2(x, edge_index))
#         x = F.relu(self.conv3(x, edge_index))
#         x = global_add_pool(x, batch)
#         x = F.dropout(x, p=0.5, training=self.training)
#         x = self.lin(x)
#         return F.log_softmax(x, dim=-1)

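# # Load the Cora citation graph (Planetoid); NormalizeFeatures row-normalizes
# # the node features. Cora contains a single graph, taken as dataset[0].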
# dataset = Planetoid(root='data/Cora', name='Cora', transform=T.NormalizeFeatures())
# data = dataset[0]

# # Define the device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Instantiate the model and move it to the device
# model = GIN(num_features=dataset.num_features, num_classes=dataset.num_classes).to(device)

# # Define the optimizer and the learning rate
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

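# # Train the model for 200 epochs, feeding the whole graph each epoch.
# # NOTE(review): data.batch is None for a single Planetoid graph, so
# # data.batch.to(device) raises AttributeError. The masks and labels are
# # attributes of data (data.train_mask, data.y), and the labels must be moved
# # to `device` to match `out` -- verify against the installed PyG version.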
# model.train()
# for epoch in range(200):
#     optimizer.zero_grad()
#     out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
#     loss = F.nll_loss(out[dataset.train_mask], dataset.y[dataset.train_mask])
#     loss.backward()
#     optimizer.step()

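# # Evaluate the model.
# # NOTE(review): accuracy is computed over all nodes, training nodes
# # included; restricting it to data.test_mask (under torch.no_grad()) would
# # measure generalization instead. data.batch is None for a single Planetoid
# # graph, so data.batch.to(device) raises AttributeError here as well.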
# model.eval()
# out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
# pred = out.argmax(dim=-1)
# correct = float(pred.eq(data.y.to(device)).sum().item())
# acc = correct / len(data.y)
# print('Accuracy: {:.4f}'.format(acc))
